#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon May 24 09:12:44 2021

Useful functions for Hg0 comparisons

@author: arifeinberg
"""
from os import path
import xarray as xr
import pandas as pd
import numpy as np
from calendar import monthrange

def _month_indices(obs_mon_i):
    """Translate one month specification into 0-based time indices.

    Returns None when no month is given (None or NaN), meaning that the
    annual mean should be used. Otherwise returns a list of 0-based indices.
    Accepts a string of one or more comma-separated 1-based months
    (e.g. "3" or "6,7,8") or a single integer-valued number.
    """
    if obs_mon_i is None:
        return None
    if isinstance(obs_mon_i, str):
        # "3".split(",") -> ["3"], so single and multiple months share a path
        return [int(im) - 1 for im in obs_mon_i.split(",")]  # minus 1 for 0 subscripting in python
    if np.isnan(obs_mon_i):
        return None
    if int(obs_mon_i) != obs_mon_i:
        raise ValueError("Month must be integer-valued, got %r" % (obs_mon_i,))
    return [int(obs_mon_i) - 1]


def get_mod_data(mod_data, obs_lat, obs_lon, obs_mon=None):
    """Extract model values at the grid cells nearest to observation sites.

    Parameters
    ----------
    mod_data : xarray DataArray
        Model field with `time`, `lat` and `lon` coordinates. Months are
        mapped to time indices as month - 1 (assumes the first time step
        is January -- TODO confirm against callers).
    obs_lat, obs_lon : scalar or indexable sequence
        Site latitude(s) and longitude(s). A scalar latitude delegates to
        get_mod_data_single.
    obs_mon : indexable sequence, optional
        Per-site month specification: NaN/None for the annual mean, a number
        or string for one month, or a comma-separated string ("6,7,8") for
        the unweighted mean of several months. If omitted, the annual mean
        is used for every site.

    Returns
    -------
    pandas Series (index 0..n-1) with one model value per site, or the
    result of get_mod_data_single for a scalar site.
    """
    if np.ndim(obs_lat) == 0:  # a single site (float, int or numpy scalar)
        return get_mod_data_single(mod_data, obs_lat, obs_lon, obs_mon)

    n_sites = len(obs_lat)
    model_dd_f = pd.Series(np.zeros(n_sites))
    model_ann_mean = annual_avg(mod_data)  # take annual mean once, not at each point in loop

    # Index with [i] (not enumerate/zip) so that pandas inputs keep their
    # original label-based lookup semantics.
    for i in range(n_sites):
        month_idx = None if obs_mon is None else _month_indices(obs_mon[i])

        if month_idx is None:  # no month information: annual mean
            site_val = model_ann_mean.sel(
                lon=obs_lon[i], lat=obs_lat[i], method='nearest')
        else:
            # pick the grid cell first so that only a 1-D time series is averaged
            site_ts = mod_data.sel(
                lon=obs_lon[i], lat=obs_lat[i], method='nearest')
            if len(month_idx) == 1:
                site_val = site_ts.isel(time=month_idx[0])
            else:
                site_val = site_ts.isel(time=month_idx).mean("time")

        model_dd_f[i] = float(site_val)

    return model_dd_f

def get_mod_data_single(mod_data, obs_lat, obs_lon, obs_mon=None): # when not a list of sites, simplify function
    """Extract the model value at the grid cell nearest to a single site.

    Parameters
    ----------
    mod_data : xarray DataArray
        Model field with `time`, `lat` and `lon` coordinates.
    obs_lat, obs_lon : scalar
        Site latitude and longitude.
    obs_mon : int, float, str or None, optional
        None or NaN: mean over all time steps. An integer-valued number or a
        string such as "3": that month (time index month - 1). A
        comma-separated string such as "6,7,8": unweighted mean of those
        months.

    Returns
    -------
    xarray DataArray holding the selected value.

    Raises
    ------
    TypeError
        If obs_mon has an unsupported type.
    ValueError
        If obs_mon is a non-integer number or an unparsable string.

    Notes
    -----
    The all-time mean here is a plain (unweighted) mean over `time`, whereas
    get_mod_data uses the day-weighted annual_avg.
    """
    if obs_mon is None:
        month_idx = None
    elif isinstance(obs_mon, str):
        month_idx = [int(m) - 1 for m in obs_mon.split(",")]
    elif isinstance(obs_mon, (int, np.integer)):
        month_idx = [int(obs_mon) - 1]
    elif isinstance(obs_mon, (float, np.floating)):
        if np.isnan(obs_mon):
            month_idx = None  # missing month -> use the mean over time
        elif float(obs_mon).is_integer():
            month_idx = [int(obs_mon) - 1]
        else:
            raise ValueError("Month must be integer-valued, got %r" % (obs_mon,))
    else:
        raise TypeError(
            "obs_mon must be None, a number or a string, got %s"
            % type(obs_mon).__name__)

    if month_idx is None: # use annual average of model
        return mod_data.sel(
            lon=obs_lon, lat=obs_lat, method='nearest').mean("time")

    if len(month_idx) == 1: # single month
        return mod_data.isel(time=month_idx[0]).sel(
            lon=obs_lon, lat=obs_lat, method='nearest')

    # several months: select the grid cell first, then average its time series
    return mod_data.sel(
        lon=obs_lon, lat=obs_lat, method='nearest').isel(time=month_idx).mean("time")

def get_chc_obs(source): # get observations for Chacaltaya
    """Load a Chacaltaya (CHC) observation file into a time-indexed DataFrame.

    Parameters
    ----------
    source : str or file-like
        Semicolon-delimited file, passed to pandas.read_csv. The first 35 data
        rows are treated as descriptive information and discarded, and the
        remaining rows must have exactly two columns (timestamp, value).

    Returns
    -------
    pandas DataFrame with columns 'tstamp' (UTC datetimes, parsed day-first)
    and 'value' (float), indexed by 'tstamp'. Rows without a value are dropped.
    """
    chc_raw = pd.read_csv(source, delimiter=";")

    # The first 35 rows of the csv file are descriptive information.
    # Copy the slice so the assignments below act on an independent frame
    # rather than on a view of chc_raw.
    chc = chc_raw.iloc[35:].copy()

    # rename columns appropriately
    chc.columns = ['tstamp', 'value']

    # drop NA values in the dataframe and renumber the rows
    chc = chc.dropna(subset=['value']).reset_index(drop=True)

    # concentrations are strings that may use a decimal comma ("1,25");
    # normalise to a decimal point and convert to floating point
    chc['value'] = (
        chc['value'].astype(str).str.replace(",", ".", regex=False).astype(float)
    )

    # Convert that column into a datetime datatype
    chc['tstamp'] = pd.to_datetime(chc['tstamp'], dayfirst=True, utc=True)

    chc.index = chc['tstamp'] # Set the datetime column as the index

    return chc

def get_data(station):#BAR,CAL,CST, MAN,NIK, SIS
    """Load the observation time series of one GMOS station.

    Parameters
    ----------
    station : str
        Station code, case-insensitive: one of BAR, CAL, CST, MAN, NIK, SIS.

    Returns
    -------
    pandas DataFrame with the columns 'tstamp' (UTC datetimes, parsed
    day-first), 'value' and 'uom' (unit of measurement) for a known station.
    For an unknown station code an explanatory message string is returned
    instead (kept for backward compatibility with existing callers).
    """
    # convert string to uppercase just in case the case is different
    station = str(station.upper())
    # file path of each station, relative to the observation directory
    sites = {'BAR':'Bariloche/BAR.csv', 'CAL':'Calhau/CAL.csv','CST':'Celestun/CST.csv','MAN':'Manaus/MAN.csv','NIK':'Niew Nickerie/NIK.csv','SIS':'Sisal/SIS.csv' }

    # error message if user enters a name that is not any of the sites
    if station not in sites:
        return "Please try again, use either one of the following 'BAR','CAL','CST', 'MAN','NIK', 'SIS' "

    # read the file once and take the required columns from it
    source = '../../../d1/tzd/GMOS_Observations/GMOS_Observations/' + sites[station]
    site = pd.read_csv(source)

    time = pd.to_datetime(site['tstamp'], dayfirst=True, utc=True)

    # return a data frame with the timestamp, value and units of measurement
    return pd.concat([time, site['value'], site['uom']], axis=1)
    
def get_obs_ts(site_str):
    """Return the observed time series for a site.

    Chacaltaya ('CHC') has its own file format and is stored as one file per
    year; its 2014 and 2015 files are read with get_chc_obs and stacked
    row-wise. Every other site code is passed on to get_data.
    """
    # all regular sites share one loader
    if site_str != 'CHC':
        return get_data(site_str)

    # Location of the yearly Chacaltaya files in Svante
    chc_files = [
        '../../../d1/tzd/GMOS_Observations/2014/L1_TGM_CHC_2014.csv',
        '../../../d1/tzd/GMOS_Observations/2015/L1_TGM_CHC_2015.csv',
    ]
    # combine data from the two years
    yearly_obs = [get_chc_obs(csv_file) for csv_file in chc_files]
    return pd.concat(yearly_obs, axis=0)

def get_model_ts(site_str, run_str): # get model time series
    """Return the modelled Hg0 time series at a site, in ng m^-3.

    Parameters
    ----------
    site_str : str
        Site code used in the name of the site output file.
    run_str : str
        Run identifier used in the name of the simulation directory.

    Returns
    -------
    The SpeciesConc_Hg0 variable at the model level chosen for the site,
    multiplied by the mixing-ratio to ng m^-3 factor, or None (after printing
    a message) when the site file does not exist.
    """
    out_dir = '../GEOS-Chem_runs/run' + run_str + '/OutputDir/' # path of simulation
    ts_file = out_dir + 'GEOSChem.SpeciesConc' + site_str + '.alltime_d.nc4' # time series at site

    # bail out early if the site file is missing
    if not path.exists(ts_file):
        print('Error with filename or site string incorrect')
        return None

    dataset = xr.open_dataset(ts_file)

    # 0-based level index of sites that are not at the surface:
    # CHC altitude is close to GEOS-Chem level 20, TIT to level 15
    elevated_sites = {'CHC': 19, 'TIT': 14}
    lev_idx = elevated_sites.get(site_str, 0) # all other sites: surface
    conc_vmr = dataset.SpeciesConc_Hg0.isel(lev = lev_idx).squeeze()

    # constants for the conversion to ng m^-3 at standard conditions
    stdpressure = 101325 # Pascals
    stdtemp = 273.15 # Kelvins
    R = 8.314462 # m^3 Pa K^-1 mol ^-1
    MW_Hg = 200.59 # g mol^-1
    ng_g = 1e9 # ng/g

    # converter from vmr to ng m^-3
    vmr_to_ng_m3 = stdpressure / R / stdtemp * MW_Hg * ng_g

    return conc_vmr * vmr_to_ng_m3

def ds_sel_yr (ds, varname, Year):
    """ Extract a variable from a dataset, optionally restricted to given year(s)

    Parameters
    ----------
    ds : xarray dataset
        Dataset of simulation to extract data from
    varname : string
        Name of parameter to extract
    Year : int, collection of ints, or None
        Year(s) to keep; None keeps every time step

    Returns
    -------
    The variable `varname`, limited to the time steps whose year is in `Year`
    when a year is given.
    """
    selected = ds[varname]

    # no year requested: hand back the complete variable
    if Year is None:
        return selected

    # boolean mask over time, true where the year matches
    in_requested_years = ds.time.dt.year.isin(Year)
    return selected.sel(time=in_requested_years)

def annual_avg (var_to_avg):
    """ Take annual average from monthly data, accounting for the difference in day number

    Parameters
    ----------
    var_to_avg : xarray dataArray
        variable to average; must have a `time` coordinate at monthly
        resolution

    Returns
    -------
    xarray dataArray
        Mean over `time`, with each time step weighted by the number of days
        in its calendar month (leap years included).

    Raises
    ------
    ValueError
        If the first two time steps are not exactly one month apart.
    """
    time_v = var_to_avg.time
    years = time_v.dt.year.values
    months = time_v.dt.month.values
    n_time = len(time_v) # number of timesteps

    # first check that the data is in monthly time resolution; include the
    # year so that a December -> January step counts as one month. A single
    # time step has no step to check.
    if n_time > 1:
        step = (int(years[1]) - int(years[0])) * 12 + (int(months[1]) - int(months[0]))
        if step > 1: # more than monthly time difference
            raise ValueError("Simulation needs to be in monthly-averaged time resolution. This data is super-monthly resolution")
        if step < 1: # less than monthly time difference
            raise ValueError("Simulation needs to be in monthly-averaged time resolution. This data is sub-monthly resolution")

    # number of days in each month, accounting for leap year
    days_in_month = np.array(
        [monthrange(int(yr), int(mon))[1] for yr, mon in zip(years, months)],
        dtype=float,
    )

    # for weighted average, need this variable as an xarray
    days_in_month_xr = xr.DataArray(
        data=days_in_month,
        dims=["time"],
        coords=dict(
            time=time_v
            ),
        attrs=dict(
            description="Days in month of simulation",
            ),
        )
    # adding weights to time dimension
    wgted_var = var_to_avg.weighted(days_in_month_xr)
    # taking mean weighted by month length
    wgted_mean = wgted_var.mean("time")

    return wgted_mean
